In [1]:
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import tqdm as tqdm
import sklearn.metrics 
import hdbscan
import time
import random
import datetime
%matplotlib inline  
from datetime import datetime, timedelta

Clustering from 2D data down to discrete units

Load in data

In [2]:
recon = np.load('/mnt/cube/tsainbur/github_repos/ModelComparisonProject/data/2D_syllable_coordinates/z_values_ST_syllables_silence_just_B1114.npz')
In [3]:
syllable_z = recon['recon_z']
syllable_time = recon['recon_time']
bird_name = recon['recon_name']
specs = recon['all_x']
recon_length = recon['recon_length']

recon_folder = recon['recon_folder'] 
recon_t_rel_wav = recon['recon_t_rel_wav']
In [4]:
syllable_time = np.array([datetime.strptime(i, "%d/%m/%y %H:%M:%S.%f") for i in syllable_time])
In [5]:
radius = np.array([np.linalg.norm([0,0]-syllable_z[i]) for i in range(len(syllable_z))])
syllable_z_log = np.log(radius)[:,None] / radius[:,None] * syllable_z
In [6]:
BirdData = pd.DataFrame({
        'specs':specs.tolist(), 
        'syllable_z_log':syllable_z_log.tolist(),
        'syllable_time':syllable_time.tolist(),
        'bird_name':bird_name.tolist(),
        'recon_length': recon_length.tolist(),
        'recon_folder': recon_folder.tolist(),
        'recon_t_rel_wav': recon_t_rel_wav.tolist(),
        
    })
In [7]:
BirdData[0:3]
Out[7]:
bird_name recon_folder recon_length recon_t_rel_wav specs syllable_time syllable_z_log
0 b1114_phys /mnt/cube/earneodo/bci_zf/ss_data/b1114/day-mo... 1.318222 5654.049805 [0.0, 0.0, 0.00147247314453, 0.171780452132, 0... 2000-01-01 10:34:14.049676 [-5.7796285596, 0.327556844158]
1 b1114 /mnt/cube/earneodo/bci_zf/raw_data/b1114/011/2... 1.604686 3.010195 [0.0, -5.96311194867e-18, 0.00812097731978, 0.... 2017-02-21 15:51:25.010195 [-2.51219306028, 5.222876706]
2 b1114 /mnt/cube/earneodo/bci_zf/raw_data/b1114/008/2... 1.006271 116.206177 [0.0, 0.0, 0.0, 0.0804443359375, 0.12722791731... 2017-02-20 20:03:01.206173 [-4.50431710166, 0.942068663067]
In [8]:
#test = BirdData[BirdData['bird_name'] == 'b1114'].sort_values(['sequence_num','sequence_syllable'])
#test[0:50]
In [9]:
print "Total Data (Hours): ",np.sum(BirdData['recon_length'])/60/60
Total Data (Hours):  1.74390013748

Remove short syllables from analysis

In [10]:
length_cutoff = 0.15
In [11]:
fig, ax= plt.subplots(nrows=1,ncols=1,figsize=(16,4))
_ = plt.hist(BirdData['recon_length'],bins=160)
plt.title('Distribution of syllable lengths (seconds)')
plt.axvline(x=length_cutoff, ymin=0, ymax = 1000, color='red')
Out[11]:
<matplotlib.lines.Line2D at 0x7f1013b5ef90>
In [12]:
BirdData = BirdData[BirdData['recon_length'] > length_cutoff]
In [13]:
[(bird, np.sum(BirdData['bird_name'] == bird)) for bird in np.unique(BirdData['bird_name'])]
Out[13]:
[('b1114', 6615), ('b1114_phys', 4041)]
In [14]:
BirdData.index = range(len(BirdData))

Break dataset into sequences

  • for each bird label the day of sequence, the sequence number, and the number within the sequence
In [15]:
def split_seq_by_time(times, idxs, max_timedelta = 30):
    idxs_sorted = idxs[times.argsort()]
    times.sort()
    time_before = np.concatenate(
        ([0.],[(times[i] - times[i-1])/np.timedelta64(1, 's')
          for i in np.arange(1,len(times))]))
    sequence_breaks = np.unique(np.concatenate((
                np.where(time_before > max_timedelta)[0], np.array([0,len(times)]))))
    idx_seqs = [idxs_sorted[sequence_breaks[i]:sequence_breaks[i+1]] for 
           i in range(len(sequence_breaks[:-1]))]
    return idx_seqs
In [16]:
# maximum amount of time allowed to pass before considering this bout new
max_timedelta = 10.
In [17]:
# label by day, and sequence within day
BirdData['sequence_num'] = -2
BirdData['day_num'] = -2
BirdData['sequence_syllable'] = -2

seq_lengths = []

all_dates = [i.date() for i in BirdData['syllable_time']]
seq_num_tot = 0
for bird in np.unique(bird_name):   
    """if bird == 'b1114':
        break
    else:
        continue"""
    #For each bird label the day
    bird_dates = [i.date() for i in BirdData[BirdData['bird_name'] == bird]['syllable_time']]
    for i,date in enumerate(tqdm.tqdm(np.unique(bird_dates))):
        #BirdData.loc[np.array((BirdData['bird_name'] == bird) & (np.array(all_dates) == date)), 'day_num'] = i
        BirdData.loc[((BirdData['bird_name'] == bird) & (np.array(date) == all_dates)), 'day_num'] = i
    
    # For each bird label the sequence number
    bird_times = BirdData[BirdData['bird_name']==bird]['syllable_time']
    idx_seqs = split_seq_by_time(np.array(bird_times.values),
                                 np.array(bird_times.index),
                                 max_timedelta=max_timedelta)
    for seq_i, idxs in tqdm.tqdm(enumerate(idx_seqs)):
        seq_lengths.append(len(idxs))
        BirdData.loc[idxs, 'sequence_num'] = seq_num_tot
        seq_num_tot+=1
        
        # Label the syllable number
        BirdData.loc[BirdData.loc[idxs].sort_values('syllable_time').index, 'sequence_syllable'] = np.arange(len(idxs))
            
 
    print bird
100%|██████████| 4/4 [00:00<00:00, 57.93it/s]
18it [00:00, 207.39it/s]
100%|██████████| 2/2 [00:00<00:00, 57.22it/s]
0it [00:00, ?it/s]
b1114
164it [00:00, 289.34it/s]
b1114_phys

In [18]:
BirdData[0:3]
Out[18]:
bird_name recon_folder recon_length recon_t_rel_wav specs syllable_time syllable_z_log sequence_num day_num sequence_syllable
0 b1114_phys /mnt/cube/earneodo/bci_zf/ss_data/b1114/day-mo... 1.318222 5654.049805 [0.0, 0.0, 0.00147247314453, 0.171780452132, 0... 2000-01-01 10:34:14.049676 [-5.7796285596, 0.327556844158] 69 0 1
1 b1114 /mnt/cube/earneodo/bci_zf/raw_data/b1114/011/2... 1.604686 3.010195 [0.0, -5.96311194867e-18, 0.00812097731978, 0.... 2017-02-21 15:51:25.010195 [-2.51219306028, 5.222876706] 12 2 139
2 b1114 /mnt/cube/earneodo/bci_zf/raw_data/b1114/008/2... 1.006271 116.206177 [0.0, 0.0, 0.0, 0.0804443359375, 0.12722791731... 2017-02-20 20:03:01.206173 [-4.50431710166, 0.942068663067] 6 1 12
In [19]:
syllable_seq = BirdData[BirdData['sequence_num'] == BirdData.loc[np.argmax(BirdData['sequence_syllable']), 'sequence_num']]
syllable_seq = syllable_seq.sort_values('sequence_syllable')
In [20]:
np.array(syllable_seq['syllable_time'])
Out[20]:
array(['2017-02-21T15:50:59.389060000', '2017-02-21T15:51:00.739155000',
       '2017-02-21T15:51:02.061786000', ...,
       '2017-02-21T15:54:31.052936000', '2017-02-21T15:54:31.386738000',
       '2017-02-21T15:54:32.018061000'], dtype='datetime64[ns]')
In [21]:
np.array(syllable_seq['syllable_time'])[0]
Out[21]:
numpy.datetime64('2017-02-21T15:50:59.389060000')
In [ ]:
 
In [22]:
bird_times = BirdData[BirdData['bird_name']==bird]['syllable_time']
print bird_times[0:5]
print len(bird_times)
0   2000-01-01 10:34:14.049676
3   2000-01-01 07:43:30.658584
6   2000-01-01 06:33:54.409521
7   2000-01-01 11:00:58.224686
8   2000-01-01 14:50:13.727304
Name: syllable_time, dtype: datetime64[ns]
4041
In [23]:
"""def split_seq_by_time(times, idxs, max_timedelta = 30):
    idxs_sorted = idxs[times.argsort()]
    times.sort()
    time_before = np.concatenate(
        ([0.],[(times[i] - times[i-1])/np.timedelta64(1, 's')
          for i in np.arange(1,len(times))]))
    sequence_breaks = np.unique(np.concatenate((
                np.where(time_before > max_timedelta)[0], np.array([0,len(times)]))))
    print sequence_breaks
    print times[sequence_breaks[:-1]]
    print times[sequence_breaks[:-1]-1]
    print time_before[sequence_breaks[:-1]]
    print len(sequence_breaks), len(time_before), len(times)
    idx_seqs = [idxs_sorted[sequence_breaks[i]:sequence_breaks[i+1]] for 
           i in range(len(sequence_breaks[:-1]))]
    return idx_seqs"""
Out[23]:
"def split_seq_by_time(times, idxs, max_timedelta = 30):\n    idxs_sorted = idxs[times.argsort()]\n    times.sort()\n    time_before = np.concatenate(\n        ([0.],[(times[i] - times[i-1])/np.timedelta64(1, 's')\n          for i in np.arange(1,len(times))]))\n    sequence_breaks = np.unique(np.concatenate((\n                np.where(time_before > max_timedelta)[0], np.array([0,len(times)]))))\n    print sequence_breaks\n    print times[sequence_breaks[:-1]]\n    print times[sequence_breaks[:-1]-1]\n    print time_before[sequence_breaks[:-1]]\n    print len(sequence_breaks), len(time_before), len(times)\n    idx_seqs = [idxs_sorted[sequence_breaks[i]:sequence_breaks[i+1]] for \n           i in range(len(sequence_breaks[:-1]))]\n    return idx_seqs"
In [24]:
"""[ len(i) for i in split_seq_by_time(np.array(bird_times.values),
                                 np.array(bird_times.index),
                                 max_timedelta=.01)]"""
Out[24]:
'[ len(i) for i in split_seq_by_time(np.array(bird_times.values),\n                                 np.array(bird_times.index),\n                                 max_timedelta=.01)]'
In [25]:
fig, ax= plt.subplots(nrows=1,ncols=1,figsize=(16,4))
_ = plt.hist(seq_lengths, bins = 100)
plt.title('Distribution of sequence lengths (in syllables)')
Out[25]:
<matplotlib.text.Text at 0x7f100b8b0450>
In [26]:
BirdData[0:3]
Out[26]:
bird_name recon_folder recon_length recon_t_rel_wav specs syllable_time syllable_z_log sequence_num day_num sequence_syllable
0 b1114_phys /mnt/cube/earneodo/bci_zf/ss_data/b1114/day-mo... 1.318222 5654.049805 [0.0, 0.0, 0.00147247314453, 0.171780452132, 0... 2000-01-01 10:34:14.049676 [-5.7796285596, 0.327556844158] 69 0 1
1 b1114 /mnt/cube/earneodo/bci_zf/raw_data/b1114/011/2... 1.604686 3.010195 [0.0, -5.96311194867e-18, 0.00812097731978, 0.... 2017-02-21 15:51:25.010195 [-2.51219306028, 5.222876706] 12 2 139
2 b1114 /mnt/cube/earneodo/bci_zf/raw_data/b1114/008/2... 1.006271 116.206177 [0.0, 0.0, 0.0, 0.0804443359375, 0.12722791731... 2017-02-20 20:03:01.206173 [-4.50431710166, 0.942068663067] 6 1 12
In [27]:
#BirdData[BirdData['bird_name'] == bird]
In [28]:
[(bird, len(np.unique(BirdData[BirdData['bird_name'] == bird]['sequence_num']))) for bird in np.unique(BirdData['bird_name'])]
Out[28]:
[('b1114', 18), ('b1114_phys', 164)]

Plot densities of clusters

In [29]:
log_z = np.array([i for i in BirdData['syllable_z_log']])
fig, ax= plt.subplots(nrows=1,ncols=1,figsize=(16,16))
ax.scatter(log_z.T[0], log_z.T[1],color='black', alpha = 0.3, linewidth= 0, s=5)
ax.axis('off')
Out[29]:
(-8.0, 6.0, -8.0, 8.0)

Plot by clusters and by bird name

In [30]:
def cluster_data(data, algorithm, args, kwds):
    # Function taken from HDBSCAN python package website
    start_time = time.time()
    labels = algorithm(*args, **kwds).fit_predict(data)
    end_time = time.time()
    print('Clustering took {:.2f} s'.format(end_time - start_time))
    palette = sns.color_palette('husl', np.unique(labels).max() + 1)
    random.shuffle(palette)
    colors = [palette[x] if x >= 0 else (0.75, 0.75, 0.75) for x in labels]
    return labels, palette, colors
In [31]:
# Parameters for clustering
min_cluster_size_hbd = 20 # minimum of N syllables in group to be called a cluster
alpha_hbd= 1.
min_samples_hbd = min_cluster_size_hbd #1
In [32]:
syllable_labels, palette, colors = cluster_data(
    data = log_z, 
    algorithm = hdbscan.HDBSCAN,
    args=(),
    kwds={'min_cluster_size':min_cluster_size_hbd,'alpha':alpha_hbd, 'min_samples':min_samples_hbd }
)
Clustering took 0.28 s
In [33]:
BirdData['syllable_labels'] = syllable_labels
BirdData['old_labels'] = syllable_labels
In [34]:
import copy
old_labels = copy.deepcopy(syllable_labels)
colors_old = [palette[x] if x >= 0 else (0.75, 0.75, 0.75) for x in np.array(BirdData['old_labels'])]
In [35]:
fig, ax = plt.subplots(nrows=1,ncols=1, figsize=(32,32))
#fig.suptitle('bold figure suptitle', fontsize=14, fontweight='bold')

plot_kwds = {'alpha' : 0.5, 's' : 10, 'linewidths':0}
sns.set_color_codes(palette='dark')

lz = np.array([i for i in BirdData['syllable_z_log'].values])

ax.scatter(lz.T[0], lz.T[1], color=colors_old, **plot_kwds)

#for i in np.unique(old_labels):
#    ax.text(np.mean(log_z[old_labels == i].T[0]),
#               np.mean(log_z[old_labels == i].T[1]), i,
#              fontsize=12, fontweight='bold', alpha=0.5)    
ax.axis('off')
Out[35]:
(-8.0, 6.0, -8.0, 8.0)

Add nonclusted data to clusters

In [36]:
from scipy import spatial
In [37]:
#label_means = [np.mean([i for i in BirdData[BirdData['syllable_labels'] == j]['syllable_z_log'].values],axis = 0) for j in np.unique(BirdData['syllable_labels'])][1:]
In [38]:
"""all_labeled_syll = np.array(BirdData[BirdData['syllable_labels'] != -1]['syllable_labels'])
all_labeled_z = BirdData[BirdData['syllable_labels'] != -1]['syllable_z_log']
all_labeled_z = np.array([i[:] for i in all_labeled_z])"""
Out[38]:
"all_labeled_syll = np.array(BirdData[BirdData['syllable_labels'] != -1]['syllable_labels'])\nall_labeled_z = BirdData[BirdData['syllable_labels'] != -1]['syllable_z_log']\nall_labeled_z = np.array([i[:] for i in all_labeled_z])"
In [39]:
#all_means = np.array([np.mean(all_labeled_z[syllable_labels == label],axis=0) for label in np.unique(syllable_labels)])[1:]
#np.unique(syllable_labels)[1:]
In [ ]:
 
In [40]:
"""for i, label in enumerate(tqdm.tqdm(BirdData['syllable_labels'])):
    if label == -1:
        #nearest_neighbor = spatial.KDTree(all_labeled_z).query(BirdData['syllable_z_log'][i])[1]
        #nearest_label = all_labeled_syll[nearest_neighbor]
        nearest_neighbor = spatial.KDTree(label_means).query(BirdData['syllable_z_log'][i])[1]
        BirdData.loc[i,'syllable_labels'] = nearest_neighbor
"""
Out[40]:
"for i, label in enumerate(tqdm.tqdm(BirdData['syllable_labels'])):\n    if label == -1:\n        #nearest_neighbor = spatial.KDTree(all_labeled_z).query(BirdData['syllable_z_log'][i])[1]\n        #nearest_label = all_labeled_syll[nearest_neighbor]\n        nearest_neighbor = spatial.KDTree(label_means).query(BirdData['syllable_z_log'][i])[1]\n        BirdData.loc[i,'syllable_labels'] = nearest_neighbor\n"

Define testing vs training data

In [41]:
pct_in_test = .3
num_sequences = len((np.unique(BirdData['sequence_num'])))
num_seq_in_test = int(pct_in_test*num_sequences)
test_sequences = np.random.choice(num_sequences, num_seq_in_test)
In [42]:
BirdData['Holdout'] = 'Training' # set all data to training
In [43]:
BirdData['Holdout'][BirdData['sequence_num'].isin(test_sequences)] = 'Testing'
/local/home/tsainbur/.conda/envs/tim_tf/lib/python2.7/site-packages/ipykernel/__main__.py:1: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  if __name__ == '__main__':
In [44]:
BirdData.to_pickle('../../../data/Labelled_syllables/ClusteredData_syllables_silence_for_ZEKE.pickle')
#BirdData.to_csv('../../../data/Labelled_syllables/ClusteredData_syllables_silence_test.csv')

Data Visualization

Plot by clusters and by bird name

In [45]:
#np.unique(BirdData['syllable_labels'])
In [46]:
palette = sns.color_palette('husl', np.unique(np.unique(BirdData['syllable_labels'])).max() + 1)
random.shuffle(palette)
colors = [palette[x] if x >= 0 else (0.75, 0.75, 0.75) for x in np.array(BirdData['syllable_labels'])]
In [ ]:
 
In [47]:
fig, ax = plt.subplots(nrows=1,ncols=1, figsize=(20,20))
#fig.suptitle('bold figure suptitle', fontsize=14, fontweight='bold')

plot_kwds = {'alpha' : 1.0, 's' : 10, 'linewidths':0}
sns.set_color_codes(palette='dark')

lz = np.array([i for i in BirdData['syllable_z_log'].values])
labs = np.array([i for i in BirdData['syllable_labels'].values])
ax.scatter(lz.T[0], lz.T[1], color=colors, **plot_kwds)

ax.axis('off')
Out[47]:
(-8.0, 6.0, -8.0, 8.0)
In [48]:
fig, ax = plt.subplots(nrows=1,ncols=1, figsize=(20,20))

plot_kwds = {'alpha' : 0.25, 's' : 10, 'linewidths':0}

bird_palette = sns.color_palette('deep', len(np.unique(bird_name)) + 1)
sns.palplot(bird_palette)


bird_name = np.array([i for i in BirdData['bird_name'].values])

for i,bird in enumerate(np.unique(bird_name)):
    lz = np.array([j for j in BirdData['syllable_z_log'].values])

    #ax.scatter(lz.T[0], lz.T[1], color=colors, **plot_kwds)
    ax.scatter(lz[bird_name == bird].T[0], lz[bird_name == bird].T[1], color=bird_palette[i], **plot_kwds)
    
ax.axis('off')
Out[48]:
(-8.0, 6.0, -8.0, 8.0)
In [49]:
z_log = np.reshape(np.concatenate(np.array(BirdData['syllable_z_log'])), (len(BirdData['syllable_z_log']),2 ))
In [ ]:
 
In [50]:
fig, ax = plt.subplots(nrows=1,ncols=1, figsize=(20,20))

plot_kwds = {'alpha' : 0.5, 's' : 10, 'linewidths':0}

#ax.scatter(lz.T[0], lz.T[1], color=colors, **plot_kwds)
ax.scatter(z_log[:,0], z_log[:,1],
           c=np.log(np.array(BirdData['recon_length'])),cmap='viridis', **plot_kwds)

ax.axis('off')

ax.set_title('Syllable as a function of length')
Out[50]:
<matplotlib.text.Text at 0x7f100b210c10>

View cluster categories spatially

In [ ]:
 
In [51]:
# Make an alpha based color scheme
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib.cbook import get_sample_data
from matplotlib.colors import LinearSegmentedColormap

def imscatter(x, y, image, ax=None, zoom=1):
    if ax is None:
        ax = plt.gca()
    try:
        image = plt.imread(image)
    except TypeError:
        # Likely already an array...
        pass
    im = OffsetImage(image, zoom=zoom)
    x, y = np.atleast_1d(x, y)
    artists = []
    for x0, y0 in zip(x, y):
        ab = AnnotationBbox(im, (x0, y0), xycoords='data', frameon=False)
        artists.append(ax.add_artist(ab))
    ax.update_datalim(np.column_stack([x, y]))
    ax.autoscale()
    return artists
In [52]:
fig, ax = plt.subplots(nrows=1,ncols=1, figsize=(32,32))
#fig.suptitle('bold figure suptitle', fontsize=14, fontweight='bold')

plot_kwds = {'alpha' : 0.5, 's' : 10, 'linewidths':0}
sns.set_color_codes(palette='dark')

lz = np.array([i for i in BirdData['syllable_z_log'].values])
labs = np.array([i for i in BirdData['syllable_labels'].values])
specs = np.array([i for i in BirdData['specs'].values])
ax.scatter(lz.T[0], lz.T[1], color=colors, **plot_kwds)

for label in np.unique(np.unique(labs)):
    x0 = np.mean(lz[labs == label].T[0])
    y0 = np.mean(lz[labs == label].T[1])
    img_3d = plt.cm.afmhot(np.flipud(np.reshape(specs[labs == label][0],(32,32))))
    imscatter(x0, y0, img_3d, zoom=1, ax=ax)

    
ax.axis('off')
ax.set_title('Clustered Syllables')
Out[52]:
<matplotlib.text.Text at 0x7f100ac8bf10>

View syllables in same cluster

In [53]:
BirdData[0:3]
Out[53]:
bird_name recon_folder recon_length recon_t_rel_wav specs syllable_time syllable_z_log sequence_num day_num sequence_syllable syllable_labels old_labels Holdout
0 b1114_phys /mnt/cube/earneodo/bci_zf/ss_data/b1114/day-mo... 1.318222 5654.049805 [0.0, 0.0, 0.00147247314453, 0.171780452132, 0... 2000-01-01 10:34:14.049676 [-5.7796285596, 0.327556844158] 69 0 1 -1 -1 Training
1 b1114 /mnt/cube/earneodo/bci_zf/raw_data/b1114/011/2... 1.604686 3.010195 [0.0, -5.96311194867e-18, 0.00812097731978, 0.... 2017-02-21 15:51:25.010195 [-2.51219306028, 5.222876706] 12 2 139 88 88 Testing
2 b1114 /mnt/cube/earneodo/bci_zf/raw_data/b1114/008/2... 1.006271 116.206177 [0.0, 0.0, 0.0, 0.0804443359375, 0.12722791731... 2017-02-20 20:03:01.206173 [-4.50431710166, 0.942068663067] 6 1 12 -1 -1 Training
In [54]:
unique_labels = np.unique(BirdData['syllable_labels'])
In [55]:
len(unique_labels)
Out[55]:
117
In [56]:
num_cats =  len(unique_labels)
num_ex = 20
dim1 = dim2 = 32
In [57]:
canvas = np.zeros((dim1*num_ex, dim2*num_cats))
In [58]:
for ji, j in tqdm.tqdm(enumerate(unique_labels[0:num_cats])):
    specs = BirdData[BirdData['old_labels'] == j]['specs'][:num_ex].values
    for i in range(num_ex):
        if i <= len(specs):
            spec = specs[i]
            spec = np.reshape(spec, (dim1,dim2))
            canvas[i*dim1:i*dim1+dim1,ji*dim2:ji*dim2+dim2] = spec
117it [00:00, 406.02it/s]
In [59]:
## TODO: MAKE CANVAS BY BIRD
In [60]:
fig, ax = plt.subplots(nrows=1,ncols=1, figsize=(num_cats, num_ex))
ax.matshow(canvas, aspect='auto',
                    cmap=plt.cm.afmhot, origin='lower')
ax.axis('off')
    
plt.show()
In [61]:
# make an interactive 2D interface, where you click on two things and it turns them into one group
In [62]:
num_cats =  20
num_ex = 20
dim1 = dim2 = 32
In [63]:
canvas = np.zeros((dim1*num_ex, dim2*num_cats))
for ji, j in tqdm.tqdm(enumerate(unique_labels[0:num_cats])):
    specs = BirdData[BirdData['bird_name'] == 'b1080']['specs'][ji*num_ex:((ji+1)*(num_ex))].values
    for i in range(num_ex):
        if i <= len(specs):
            spec = specs[i]
            spec = np.reshape(spec, (dim1,dim2))
            canvas[i*dim1:i*dim1+dim1,ji*dim2:ji*dim2+dim2] = spec
0it [00:00, ?it/s]

IndexErrorTraceback (most recent call last)
<ipython-input-63-ceb180681a73> in <module>()
      4     for i in range(num_ex):
      5         if i <= len(specs):
----> 6             spec = specs[i]
      7             spec = np.reshape(spec, (dim1,dim2))
      8             canvas[i*dim1:i*dim1+dim1,ji*dim2:ji*dim2+dim2] = spec

IndexError: index 0 is out of bounds for axis 0 with size 0
In [ ]:
fig, ax = plt.subplots(nrows=1,ncols=1, figsize=(num_cats, num_ex))
ax.matshow(canvas, aspect='auto',
                    cmap=plt.cm.afmhot, origin='lower')
ax.axis('off')
    
plt.show()
In [ ]:
np.shape(specs)
In [ ]:
iii
In [ ]: